# ------------------------------------------------------------------------------
# Generates model predictions in val and test set for split models
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 05_predict_subscores.sh {dataset}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Pin the RNG state so any random draws made during this run are reproducible.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(glmnet)
library(xgboost)
library(optparse)
library(testit) # assert
library(here) # here() relative filepaths

# Directory that holds this step's model files and receives its outputs.
temp <- here(
  "code",
  "08_train_split_model",
  "temp"
)
# Project utility module; supplies u$sparsify() used when loading features.
u <- modules::use(
  here("lib", "util.R")
)

# Command Line Args ------------------------------------------------------------
# --dataset selects the split to score. The loading code below only knows how
# to build the "val" cohort and the random "test" cohort, so anything else is
# rejected up front instead of silently being scored against the test cohort.
valid_datasets <- c("val", "test")
arg_config <- list(
  make_option(
    "--dataset", type = "character",
    help = "Split to score: one of 'val' or 'test'."
  )
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

if (is.null(arg_list$dataset) || !(arg_list$dataset %in% valid_datasets)) {
  stop(
    "--dataset must be one of: ", paste(valid_datasets, collapse = ", "),
    call. = FALSE
  )
}

# Helpers ----------------------------------------------------------------------
# Build the prediction column name for each model file: "p__" + the file name
# with a trailing ".rds" removed (e.g. "lasso__foo.rds" -> "p__lasso__foo").
# The suffix is matched literally and only at the end of the name; an
# unanchored regex ".rds" would also eat things like "ards" mid-name.
# Vectorized over model_file; an empty input yields character(0).
get_score_name <- function(model_file) {
  sprintf("p__%s", sub("\\.rds$", "", model_file))
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))
dataset <- arg_list$dataset

# Feature matrix for the requested split. glue() interpolates any {...}
# placeholders in the configured path using variables currently in scope.
x <- readRDS(glue(paths$features[[dataset]]))
x <- u$sparsify(x)

# Cohort rows that receive the predictions: "val" uses the split-model cohort
# stored in temp/, any other split uses the shared random test cohort.
if (dataset == "val") {
  ids <- readRDS(file.path(temp, "split_val_cohort.rds"))
} else {
  ids <- readRDS(
    file.path(
      paths$modeling$dir, "cohorts", "random", "test_cohort.rds"
    )
  )
}
setDT(ids)

# Predictions are added to ids column-by-column below, so the cohort must line
# up with the feature rows. NOTE(review): this checks row counts only; row
# order is assumed to match between cohort and features -- confirm upstream.
assert(
  "cohort ids and feature matrix have the same number of rows",
  nrow(ids) == nrow(x)
)

# Model files in temp/, selected by name; GBM files mentioning "tuning" are
# excluded.
lasso_files <- list.files(temp) %>%
  str_subset("lasso__")
gbm_files <- list.files(temp) %>%
  str_subset("gbm__") %>%
  str_subset("tuning", negate = TRUE)

# Predict GBM ------------------------------------------------------------------
message("Predicting with GBMs...")
# Load every booster, then add one prediction column per model to ids (by
# reference), named via get_score_name(). outputmargin = FALSE requests the
# transformed prediction rather than the raw margin.
gbm_models <- file.path(temp, gbm_files) %>%
  map(xgb.load)
walk2(
  .x = get_score_name(gbm_files),
  .y = gbm_models,
  .f = function(score_col, booster) {
    ids[, (score_col) := predict(booster, x, outputmargin = FALSE)]
  }
)

# Predict Lasso ----------------------------------------------------------------
message("Predicting with lasso...")
# Each lasso file is read with readRDS. predict(s = "lambda.min") assumes these
# are cv.glmnet fits -- NOTE(review): confirm against the training script.
lasso_models <- map(
  file.path(temp, lasso_files),
  readRDS
)
# predict() on glmnet fits returns an n x 1 matrix; as.vector() flattens it so
# every score column in ids is a plain numeric vector, like the GBM scores.
walk2(
  get_score_name(lasso_files), lasso_models,
  ~ ids[, (.x) := as.vector(
    predict(.y, x, s = "lambda.min", type = "response")
  )]
)

# Save Data --------------------------------------------------------------------
message("Saving results...")
# One output file per split, written alongside the models in temp/.
ids %>%
  write_rds(file.path(temp, glue("subscores_{dataset}_set.rds")))

# Done -------------------------------------------------------------------------
message("Done.")
